package gateway

import (
	"context"
	"errors"
	core "ez/apps/core/document"
	"ez/apps/core/service"
	"fmt"
	"gitee.com/dreamwood/ez-go/ez"
	"gitee.com/dreamwood/ez-go/ss"
	"gitee.com/dreamwood/ez-go/tools"
	"math/rand"
	"net/http"
	"net/http/httputil"
	"sync"
)

var (
	// GateWays holds the running gateway servers, keyed by listen port.
	// InitGateWay allocates it; StartGateWay/StopGateWay register and remove
	// entries.
	GateWays map[int64]*http.Server

	// Hosts holds the load-balanced hosts of each application, keyed by the
	// host's AppId. PrepareHost rebuilds it under mt.
	Hosts map[string][]*core.Host

	// Routes holds the enabled forwarding rules. PrepareRoute rebuilds it
	// under mt.
	Routes []*core.Route

	// mt serializes replacement of Hosts and Routes; other packages reach it
	// through GetLock.
	mt sync.Mutex
)

func InitGateWay() {
	GateWays = make(map[int64]*http.Server)
	Hosts = make(map[string][]*core.Host)
	Routes = make([]*core.Route, 0)
	mt = sync.Mutex{}
	ez.LogToConsoleNoTrace("开始准备GateWay主机数据")
	PrepareHost()
	ez.LogToConsoleNoTrace("开始准备GateWay路由转发数据")
	PrepareRoute()
	ez.LogToConsoleNoTrace("开始准备AccessControl数据")
	service.CreateApiRoles()
	ez.LogToConsoleNoTrace("开始启动端口监听")
	AutoStartGateWay()
}

func init() {
	ez.Subscribe(ez.EventAfterServerRun, func(v interface{}, ctx context.Context) {
		InitGateWay()
	})
	ez.Subscribe("HostChanged", func(v interface{}, ctx context.Context) {
		PrepareHost()
	})
}

// GetLock returns the address of the package mutex that PrepareHost and
// PrepareRoute hold while replacing Hosts and Routes, so that other packages
// can synchronize with those updates.
func GetLock() *sync.Mutex {
	lock := &mt
	return lock
}

// PrepareHost reloads the host table from the database. Every host matching
// isOn=true and state=1 is grouped under its AppId, and the resulting map
// replaces Hosts. When the lookup fails (as flagged by ez.Try) Hosts is left
// unchanged.
func PrepareHost() {
	found, err := core.NewHostCrud().FindBy(ss.M{"isOn": true, "state": 1}, nil, 0, 0)
	if ez.Try(err) {
		return
	}

	// Build the new table privately; nobody else can see it yet, so no lock
	// is needed while filling it.
	byApp := make(map[string][]*core.Host)
	for _, h := range found {
		byApp[h.AppId] = append(byApp[h.AppId], h)
	}

	// Publish it under the lock so readers never observe a partial table.
	mt.Lock()
	Hosts = byApp
	mt.Unlock()
}

// PrepareRoute reloads the enabled routes (isOn=true) from the database,
// requested with the "appId" ordering argument, and replaces Routes with a
// private copy of them. When the lookup fails (as flagged by ez.Try) Routes
// is left unchanged.
func PrepareRoute() {
	crud := core.NewRouteCrud()
	loaded, err := crud.FindBy(ss.M{"isOn": true}, []string{"appId"}, 0, 0)
	if ez.Try(err) {
		return
	}

	// Copy into a fresh slice so Routes never shares storage with the result
	// returned by the crud layer.
	fresh := append(make([]*core.Route, 0, len(loaded)), loaded...)

	mt.Lock()
	defer mt.Unlock()
	Routes = fresh
}

// CreateDirector builds the Director for the httputil.ReverseProxy serving
// gateWay. For every request the returned function:
//
//  1. sets a random 32-character "ChainKey" request header;
//  2. asks service.CheckAccessByRequest whether the request is allowed and,
//     if not, sends it to this server's /_server_403?code=ACCESS_DENY;
//  3. among the routes of gateWay whose From pattern matches the path, picks
//     the one with the highest Sort (the earliest in Routes wins on ties;
//     only routes with Sort >= 0 can be picked) and rewrites the path with
//     that route's To template;
//  4. picks an online (IsOn) host of the route's AppId, weighted-random via
//     getHost when there are several, and points the request at it over
//     plain http, dropping an explicit ":80" from the host.
//
// If the AppId has no entry in Hosts, the request is sent to
// /_server_404?code=NO_HOST and a host reload is started in the background.
// If the AppId has entries but none is online, the request is sent to
// /_server_500?code=NO_HOST.
//
// Routes and Hosts are read under mt, the lock PrepareRoute/PrepareHost hold
// while replacing them. The lock is held only for the snapshot and the IsOn
// filtering, never across route matching or I/O.
func CreateDirector(gateWay *core.GateWay) func(*http.Request) {
	// sendToLocal points request at one of this server's built-in error
	// endpoints.
	sendToLocal := func(request *http.Request, path, rawQuery string) {
		request.URL.Path = path
		request.URL.RawQuery = rawQuery
		request.URL.Host = fmt.Sprintf("%s:%d", ez.ConfigServer.ServerHost, ez.ConfigServer.ServerPort)
		request.URL.Scheme = "http"
	}

	return func(request *http.Request) {
		request.Header.Set("ChainKey", tools.CreateRandString(32))

		// Access control.
		if !service.CheckAccessByRequest(request) {
			sendToLocal(request, "/_server_403", "code=ACCESS_DENY")
			return
		}

		// PrepareRoute swaps in a brand-new slice rather than mutating the old
		// one, so a snapshot taken under the lock stays valid after release.
		mt.Lock()
		routes := Routes
		mt.Unlock()

		urlRedirect := ""
		var maxSort int64 = -1
		maxSortAppId := ""
		for _, route := range routes {
			if route.GateWayId != gateWay.Id {
				continue
			}
			// Turn ":name(regex)" placeholders into named capture groups.
			pattern := tools.ReplaceAll(route.From, `:(\w+)\((.+?)\)`, "(?P<$1>$2)")
			matches := tools.MatchAll(request.URL.Path, pattern)
			if len(matches) > 0 && route.Sort > maxSort {
				urlRedirect = tools.ReplaceAll(request.URL.Path, pattern, route.To)
				maxSort = route.Sort
				maxSortAppId = route.AppId
			}
		}

		// Collect the online hosts of the selected app under the lock.
		// NOTE(review): IsOn is filtered here as well, which suggests it may be
		// toggled at runtime by code holding GetLock(); confirm.
		mt.Lock()
		hosts, ok := Hosts[maxSortAppId]
		appHosts := make([]*core.Host, 0, len(hosts))
		for _, appHost := range hosts {
			if appHost.IsOn {
				appHosts = append(appHosts, appHost)
			}
		}
		mt.Unlock()

		if !ok {
			// No host registered for this app.
			ez.LogToConsoleNoTrace(fmt.Sprintf("主机未找到:%s", maxSortAppId))
			sendToLocal(request, "/_server_404", "code=NO_HOST")
			// Reload the host table in the background in case it is stale.
			go PrepareHost()
			return
		}

		var picked *core.Host
		switch len(appHosts) {
		case 0:
			// Every host of the app is offline. Return immediately so the
			// request really goes to the 500 page instead of being forwarded.
			ez.LogToConsole("主机全部掉线")
			sendToLocal(request, "/_server_500", "code=NO_HOST")
			return
		case 1:
			picked = appHosts[0]
		default:
			// Balance across online hosts only, never the unfiltered list.
			picked = getHost(appHosts)
		}

		targetHost := fmt.Sprintf("%s:%d", picked.Ip, picked.Port)
		// Drop an explicit default http port.
		if n := len(targetHost); n >= 3 && targetHost[n-3:] == ":80" {
			targetHost = targetHost[:n-3]
		}

		if !service.CheckIsIgnored(request.URL.Path) {
			ez.LogToConsoleNoTrace("From:", request.URL.Path, "To:", urlRedirect, "Host:", targetHost)
		}

		request.URL.Host = targetHost
		request.URL.Path = urlRedirect
		request.URL.Scheme = "http"
	}
}

// Responsor is the ReverseProxy ModifyResponse hook.
//
// Responses to OPTIONS requests are passed through untouched. Every other
// response gets:
//   - "Access-Control-Allow-Origin: *";
//   - a freshly generated 64-character token in the "set-<SessionTokenName>"
//     header when the request carried no session-token header;
//   - the token returned by service.CheckAndUpdateUserToken, when non-empty,
//     in the "set-<AuthTokenName>" header.
//
// The body is never read or replaced. It always returns nil.
func Responsor(response *http.Response) error {
	if response.Request.Method == http.MethodOptions {
		return nil
	}

	response.Header.Set("Access-Control-Allow-Origin", "*")

	sessionName := ez.ConfigApi.SessionTokenName
	if response.Request.Header.Get(sessionName) == "" {
		response.Header.Set(fmt.Sprintf("set-%s", sessionName), tools.CreateRandString(64))
	}

	// Check the Auth-Token and extend it if needed.
	authToken := service.CheckAndUpdateUserToken(response)
	if authToken != "" {
		response.Header.Set(fmt.Sprintf("set-%s", ez.ConfigApi.AuthTokenName), authToken)
	}

	// NOTE: rewriting the body here (io.ReadAll + io.NopCloser) may break
	// websocket forwarding; special-case websocket upgrades before doing so.
	return nil
}

// StartGateWay serves gateway as a reverse proxy on gateway.Ip:gateway.Port.
// It blocks in ListenAndServe until the server stops, so callers normally run
// it in its own goroutine (AutoStartGateWay does).
//
// Any server already registered for the same port is shut down first.
// The GateWays map is read and written under mt, because AutoStartGateWay
// starts several of these concurrently and unsynchronized map writes are a
// fatal error in Go. The lock is released before the blocking
// Shutdown/ListenAndServe calls. Callers must not hold GetLock() when calling
// this.
func StartGateWay(gateway *core.GateWay) {
	proxyConf := httputil.ReverseProxy{
		Director:       CreateDirector(gateway),
		ModifyResponse: Responsor,
	}
	ez.LogToConsoleNoTrace(fmt.Sprintf("路由转发中心配置完成,监听端口:%d", gateway.Port))

	srv := &http.Server{
		Addr:    fmt.Sprintf("%s:%d", gateway.Ip, gateway.Port),
		Handler: &proxyConf,
	}

	// Register the new server and fetch the one it replaces in a single
	// critical section.
	mt.Lock()
	existedGtw, ok := GateWays[gateway.Port]
	GateWays[gateway.Port] = srv
	mt.Unlock()

	if ok {
		e := existedGtw.Shutdown(context.TODO())
		ez.PrintError(e)
	}

	err := srv.ListenAndServe()
	// ErrServerClosed is the normal result of Shutdown/StopGateWay.
	if err != nil && !errors.Is(err, http.ErrServerClosed) {
		ez.Debug(fmt.Sprintf("路由转发中心端口监听时发生错误:%s", err.Error()))
	}
}

func StopGateWay(port int64) (err error) {
	find, ok := GateWays[port]
	if !ok {
		err = errors.New("网关已不存在")
		return
	}
	err = find.Shutdown(context.TODO())
	delete(GateWays, port)
	return
}

// AutoStartGateWay loads every gateway definition and starts each enabled
// one (IsOn) in its own goroutine via StartGateWay. If the lookup fails, the
// error is handed to ez.Try (as PrepareHost/PrepareRoute do) and nothing is
// started.
func AutoStartGateWay() {
	gateWays, e := core.NewGateWayCrud().FindBy(ss.M{}, nil, 0, 0)
	if ez.Try(e) {
		return
	}
	for _, gateway := range gateWays {
		if !gateway.IsOn {
			continue
		}
		// gateway is evaluated now, as the argument of the go statement.
		go StartGateWay(gateway)
	}
}

// getHost picks one host for load balancing by weighted random selection:
// each host is chosen with probability proportional to its Weight.
//
// Non-positive weights count as 0. If no host has a positive weight, the
// pick falls back to a uniform random choice. This avoids a panic, because
// rand.Int63n panics when its argument is <= 0. Returns nil when hosts is
// empty.
func getHost(hosts []*core.Host) *core.Host {
	if len(hosts) == 0 {
		return nil
	}

	var total int64
	for _, host := range hosts {
		if host.Weight > 0 {
			total += host.Weight
		}
	}
	if total <= 0 {
		return hosts[rand.Intn(len(hosts))]
	}

	// Walk the cumulative weights until the random point falls inside one.
	rnd := rand.Int63n(total)
	for _, host := range hosts {
		if host.Weight <= 0 {
			continue
		}
		if rnd < host.Weight {
			return host
		}
		rnd -= host.Weight
	}
	// Unreachable while rnd < total; kept as a safe default.
	return hosts[0]
}
